{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Installing Required libraries\n",
        "Step 1.1- Please refer to the \"Technical Requirements\" section in the book for the neccessary packages to be installed. Please note that chapter has differnt Pytorch Lightning version and thus diff torch dependancies. Some functions may not work with other versions than what is tested below, so please ensure correct versions."
      ],
      "metadata": {
        "id": "PJ30oA710Pyu"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "_9KC95YlVNFL"
      },
      "outputs": [],
      "source": [
        "#Step1 Please refer to the \"Technical Requirements\" section in the book for the neccessary packages to be installed here\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "EKg5HFNHY2Dt"
      },
      "outputs": [],
      "source": [
        "#refer to book for correct version of package and import here\n",
        "!pip install torch==1.10.0 torchvision==0.11.1 torchtext==0.11.0 torchaudio==0.10.0 --quiet\n",
        "!pip install pytorch-lightning==1.5.2 --quiet\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "\n",
        "# Step 1.2 Import the neccessary packages.\n",
        " Please refer to the section, Importing the packages in the book for further details. "
      ],
      "metadata": {
        "id": "vNm-Nlqa0dcC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "from torch import nn, optim\n",
        "from torch.autograd import Variable\n",
        "import pytorch_lightning as pl\n",
        "from pytorch_lightning.callbacks import ModelCheckpoint\n",
        "from torch.utils.data import DataLoader\n",
        "from torchmetrics.functional import accuracy\n"
      ],
      "metadata": {
        "id": "ulmRq0v-stBe"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "BWKjV6ouWOra",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "6b8bc9ea-6070-4d5c-bf7c-5d91f5e9e0fb"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch version: 1.10.0+cu102\n",
            "pytorch ligthening version: 1.5.2\n"
          ]
        }
      ],
      "source": [
        "print(\"torch version:\",torch.__version__)\n",
        "print(\"pytorch ligthening version:\",pl.__version__)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "gVmSh8OSniNb"
      },
      "outputs": [],
      "source": [
        "xor_input = [Variable(torch.Tensor([0, 0])),\n",
        "           Variable(torch.Tensor([0, 1])),\n",
        "           Variable(torch.Tensor([1, 0])),\n",
        "           Variable(torch.Tensor([1, 1]))]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "vhN9NqXupEuM"
      },
      "outputs": [],
      "source": [
        "xor_target = [Variable(torch.Tensor([0])),\n",
        "           Variable(torch.Tensor([1])),\n",
        "           Variable(torch.Tensor([1])),\n",
        "           Variable(torch.Tensor([0]))]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "h_xCNEDUYyud"
      },
      "outputs": [],
      "source": [
        "xor_data = list(zip(xor_input, xor_target))\n",
        "train_loader = DataLoader(xor_data, batch_size=1000)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "vvgQJLTmXCu7"
      },
      "outputs": [],
      "source": [
        "class XORModel(pl.LightningModule):\n",
        "  def __init__(self):\n",
        "\n",
        "    super(XORModel,self).__init__()\n",
        "    \n",
        "    self.input_layer = nn.Linear(2, 4)\n",
        "    self.output_layer = nn.Linear(4,1)\n",
        "\n",
        "    self.sigmoid = nn.Sigmoid()\n",
        "\n",
        "    self.loss = nn.MSELoss()\n",
        "\n",
        "  def forward(self, input):\n",
        "    #print(\"INPUT:\", input.shape)\n",
        "    x = self.input_layer(input)\n",
        "    #print(\"FIRST:\", x.shape)\n",
        "    x = self.sigmoid(x)\n",
        "    #print(\"SECOND:\", x.shape)\n",
        "    output = self.output_layer(x)\n",
        "    #print(\"THIRD:\", output.shape)\n",
        "    return output\n",
        "\n",
        "  def configure_optimizers(self):\n",
        "    params = self.parameters()\n",
        "    optimizer = optim.Adam(params=params, lr = 0.01)\n",
        "    return optimizer\n",
        "\n",
        "  def training_step(self, batch, batch_idx):\n",
        "    xor_input, xor_target = batch\n",
        "    #print(\"XOR INPUT:\", xor_input.shape)\n",
        "    #print(\"XOR TARGET:\", xor_target.shape)\n",
        "    outputs = self(xor_input) \n",
        "    #print(\"XOR OUTPUT:\", outputs.shape)\n",
        "    loss = self.loss(outputs, xor_target)\n",
        "    return loss "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "TVaNoYCDbG13",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 364,
          "referenced_widgets": [
            "3c02f69e0bdd4bf8a6572674b5bbf3eb",
            "208c04859c6748e7b81ddcd04b3b136a",
            "2ad4b9a7a26b44aaa68a71e2dc403d45",
            "d235b3cd954d4d12ae885ae96c0540e3",
            "8ba1363a044c4dde84f48ae60f1feaa3",
            "0b45ac22fb0f4f30bf03e4a59e5f0540",
            "5949c27073984f08aa382620ba06c5af",
            "6568e9240cbc4a27b7581639710f06c7",
            "411b8e472a5e4f49a7eadc5573444183",
            "2aa78a41b2a34395ba3a7d4a39062944",
            "318b02b7b80041dca53888677abf81c4"
          ]
        },
        "outputId": "44ec11f9-20bb-4f76-a245-089ba7dc9d2e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "GPU available: False, used: False\n",
            "TPU available: False, using: 0 TPU cores\n",
            "IPU available: False, using: 0 IPUs\n",
            "\n",
            "  | Name         | Type    | Params\n",
            "-----------------------------------------\n",
            "0 | input_layer  | Linear  | 12    \n",
            "1 | output_layer | Linear  | 5     \n",
            "2 | sigmoid      | Sigmoid | 0     \n",
            "3 | loss         | MSELoss | 0     \n",
            "-----------------------------------------\n",
            "17        Trainable params\n",
            "0         Non-trainable params\n",
            "17        Total params\n",
            "0.000     Total estimated model params size (MB)\n",
            "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/data_loading.py:407: UserWarning: The number of training samples (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
            "  f\"The number of training samples ({self.num_training_batches}) is smaller than the logging interval\"\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Training: 0it [00:00, ?it/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "3c02f69e0bdd4bf8a6572674b5bbf3eb"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "from pytorch_lightning.utilities.types import TRAIN_DATALOADERS\n",
        "checkpoint_callback = ModelCheckpoint()\n",
        "model = XORModel()\n",
        "\n",
        "trainer = pl.Trainer(max_epochs=500, callbacks=[checkpoint_callback])\n",
        "\n",
        "trainer.fit(model, train_dataloaders=train_loader)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "I8mupvitfno4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "11baca15-4874-4d11-c534-6df258e6471b"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[0m\u001b[01;34mversion_0\u001b[0m/  \u001b[01;34mversion_1\u001b[0m/\n"
          ]
        }
      ],
      "source": [
        "ls lightning_logs/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "8hv9Gw6x_OvP",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "247f83d2-7465-4e5d-add9-a1506623ddf1"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "lightning_logs/version_0/:\n",
            "\u001b[0m\u001b[01;34mcheckpoints\u001b[0m/\n",
            "events.out.tfevents.1651698347.3efc55c883d6.60.0\n",
            "events.out.tfevents.1651698361.3efc55c883d6.60.1\n",
            "hparams.yaml\n",
            "\n",
            "lightning_logs/version_1/:\n",
            "\u001b[01;34mcheckpoints\u001b[0m/  events.out.tfevents.1651700398.3efc55c883d6.609.0  hparams.yaml\n"
          ]
        }
      ],
      "source": [
        "ls lightning_logs/*/"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(checkpoint_callback.best_model_path)"
      ],
      "metadata": {
        "id": "kWCYIw90-wpR",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "4bfb6071-616f-4757-b593-a98fd27fa171"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/lightning_logs/version_1/checkpoints/epoch=499-step=499.ckpt\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "8WHq0ZAQe0zK",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "8d1dcd81-f754-42b5-e835-320596e8de2d"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/lightning_logs/version_1/checkpoints/epoch=499-step=499.ckpt\n",
            "[0, 0] 0\n",
            "[0, 1] 1\n",
            "[1, 0] 1\n",
            "[1, 1] 0\n"
          ]
        }
      ],
      "source": [
        "print(checkpoint_callback.best_model_path)\n",
        "train_model = model.load_from_checkpoint(checkpoint_callback.best_model_path)\n",
        "test = torch.utils.data.DataLoader(xor_input, batch_size=1)\n",
        "for val in xor_input:\n",
        "  _ = train_model(val)\n",
        "  print([int(val[0]),int(val[1])], int(_.round()))"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "print(checkpoint_callback.best_model_path)\n",
        "train_model = model.load_from_checkpoint(checkpoint_callback.best_model_path)\n",
        "total_accuracy = []\n",
        "for xor_input, xor_target in train_loader:\n",
        "  for i in range(100):\n",
        "    output_tensor = train_model(xor_input)\n",
        "    test_accuracy = accuracy(output_tensor, xor_target.int())\n",
        "    total_accuracy.append(test_accuracy)\n",
        "total_accuracy = torch.mean(torch.stack(total_accuracy))\n",
        "print(\"TOTAL ACCURACY FOR 100 ITERATIONS: \", total_accuracy.item())"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "RkHHZThZZphY",
        "outputId": "4ace4418-ee73-45dd-8336-1b447f70153d"
      },
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "/content/lightning_logs/version_1/checkpoints/epoch=499-step=499.ckpt\n",
            "TOTAL ACCURACY FOR 100 ITERATIONS:  1.0\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "VqOXMSBMn1G9"
      },
      "execution_count": 15,
      "outputs": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "Github Copy of Perceptron_model_XOR.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.6"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "3c02f69e0bdd4bf8a6572674b5bbf3eb": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_208c04859c6748e7b81ddcd04b3b136a",
              "IPY_MODEL_2ad4b9a7a26b44aaa68a71e2dc403d45",
              "IPY_MODEL_d235b3cd954d4d12ae885ae96c0540e3"
            ],
            "layout": "IPY_MODEL_8ba1363a044c4dde84f48ae60f1feaa3"
          }
        },
        "208c04859c6748e7b81ddcd04b3b136a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_0b45ac22fb0f4f30bf03e4a59e5f0540",
            "placeholder": "​",
            "style": "IPY_MODEL_5949c27073984f08aa382620ba06c5af",
            "value": "Epoch 499: 100%"
          }
        },
        "2ad4b9a7a26b44aaa68a71e2dc403d45": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_6568e9240cbc4a27b7581639710f06c7",
            "max": 1,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_411b8e472a5e4f49a7eadc5573444183",
            "value": 1
          }
        },
        "d235b3cd954d4d12ae885ae96c0540e3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_2aa78a41b2a34395ba3a7d4a39062944",
            "placeholder": "​",
            "style": "IPY_MODEL_318b02b7b80041dca53888677abf81c4",
            "value": " 1/1 [00:00&lt;00:00, 43.91it/s, loss=9.09e-14, v_num=1]"
          }
        },
        "8ba1363a044c4dde84f48ae60f1feaa3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": "inline-flex",
            "flex": null,
            "flex_flow": "row wrap",
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": "100%"
          }
        },
        "0b45ac22fb0f4f30bf03e4a59e5f0540": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "5949c27073984f08aa382620ba06c5af": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "6568e9240cbc4a27b7581639710f06c7": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": "2",
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "411b8e472a5e4f49a7eadc5573444183": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "2aa78a41b2a34395ba3a7d4a39062944": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "318b02b7b80041dca53888677abf81c4": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}